"""
Poisson solver for the kernel‑to‑metric simulation.

We solve a discrete Poisson equation on a square lattice by convolving
the source with a regularised 3‑D Green’s function ``1/sqrt(r^2 + eps^2)``
(no ``1/(4*pi)`` prefactor is applied).  The convolution is performed via
the Fast Fourier Transform (FFT) with zero‑padding to avoid wrap‑around
artefacts.  This implements an aperiodic convolution: the charge density
is zero‑padded into a larger array, the kernel is evaluated directly on
that padded grid with its origin (r = 0) shifted to index ``[0, 0]``, and
the result is cropped back to the original lattice size.

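The module exposes a high‑level interface ``compute_potential`` which
multiplies the source density by the coupling ``lambd`` (λ) and returns
the potential together with the field components (the negative gradient
of the potential).  The gradient is computed by finite differences
(``np.gradient``: central differences in the interior, unit lattice
spacing).  All arrays are NumPy ``float64``.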
"""

from __future__ import annotations

import numpy as np
from typing import Tuple


def _build_green_kernel(pad_L: int, eps: float) -> np.ndarray:
    """Construct a regularised 3‑D Green’s function kernel.

    The kernel has shape ``(pad_L, pad_L)`` and is defined by
    ``G(x,y) = 1/sqrt(x^2 + y^2 + eps^2)`` on integer offsets
    ``-(pad_L // 2) .. pad_L - pad_L // 2 - 1`` along each axis.  The
    spatial origin (offset ``(0, 0)``) is then moved to index ``[0, 0]``
    with ``np.fft.ifftshift``.  This is the wrap‑around layout that
    circular (FFT) convolution expects.

    Parameters
    ----------
    pad_L : int
        Side length of the padded arrays.  Must be at least twice the
        original lattice size to avoid circular wrap‑around.
    eps : float
        Regularisation parameter to avoid divergence at the origin.  Must
        be positive and finite.

    Returns
    -------
    kernel_shifted : ndarray
        ``float64`` kernel of shape ``(pad_L, pad_L)`` with the origin at
        ``[0, 0]`` and negative offsets wrapped to the end of each axis.

    Raises
    ------
    ValueError
        If ``pad_L < 1`` or ``eps`` is not a positive finite number.
    """
    if pad_L < 1:
        raise ValueError(f"pad_L must be a positive integer, got {pad_L}")
    # eps == 0 would put an infinite value at the origin; enforce the
    # documented contract instead of silently producing inf.
    if not (np.isfinite(eps) and eps > 0):
        raise ValueError(f"eps must be positive and finite, got {eps}")
    coords = np.arange(pad_L, dtype=float) - (pad_L // 2)
    x, y = np.meshgrid(coords, coords, indexing="ij")
    kernel = 1.0 / np.sqrt(x**2 + y**2 + eps**2)
    # The spatial origin (r = 0) sits at index pad_L // 2; ifftshift moves it
    # to [0, 0] so that offset d >= 0 lives at index d and offset d < 0 at
    # index pad_L + d.
    return np.fft.ifftshift(kernel)


def aperiodic_convolution(source: np.ndarray,
                          ell: int,
                          epsilon_factor: float = 0.5) -> np.ndarray:
    """Convolve a 2‑D source field with the regularised 3‑D Green’s function.

    The source field has shape ``(n_rows, n_cols)``; the common case is a
    square ``(L, L)`` lattice.  It is zero‑padded into the top‑left corner
    of a square array of side ``2 * max(n_rows, n_cols)`` and convolved with
    the kernel ``1/sqrt(r^2 + eps^2)`` using real FFTs.  The result is
    cropped back to the top‑left ``(n_rows, n_cols)`` region.  Because of the
    padding, that region equals the linear (aperiodic) convolution, free of
    circular wrap‑around.

    Parameters
    ----------
    source : ndarray
        Two‑dimensional source density to be convolved with the Green’s
        function.  It is converted to ``float64``.
    ell : int
        Smoothing width used to set the regularisation ``eps``.  The
        parameter ``eps`` is ``epsilon_factor * ell``, clamped from below
        to ``1e-6`` (so zero or negative ``ell`` gives ``eps = 1e-6``).
    epsilon_factor : float, optional
        Multiplicative factor for ``ell`` when computing the
        regularisation parameter.  Defaults to 0.5.

    Returns
    -------
    V : ndarray
        Potential field with the same shape as ``source``.

    Raises
    ------
    ValueError
        If ``source`` is not two‑dimensional.
    """
    source = np.asarray(source, dtype=float)
    # Without this check a 1-D or (L, 1) array would be silently broadcast
    # into the padded buffer and produce a meaningless (L, L) result.
    if source.ndim != 2:
        raise ValueError(
            f"source must be a 2-D array, got shape {source.shape}"
        )
    if source.size == 0:
        # Nothing to convolve (and a zero-length FFT would raise).
        return np.zeros(source.shape, dtype=float)

    n_rows, n_cols = source.shape
    # Offsets between lattice sites span at most n - 1 along each axis;
    # a padded side of 2 * max(n_rows, n_cols) keeps all of them inside the
    # kernel's non-wrapping range.  For square input this is exactly 2L.
    pad_L = 2 * max(n_rows, n_cols)
    # Determine regularisation; avoid zero eps
    eps = max(epsilon_factor * float(ell), 1e-6)
    kernel_shifted = _build_green_kernel(pad_L, eps)

    S_pad = np.zeros((pad_L, pad_L), dtype=float)
    S_pad[:n_rows, :n_cols] = source

    # Both operands are real, so the half-spectrum transforms are sufficient;
    # s= restores the full (pad_L, pad_L) shape on the inverse.
    F_S = np.fft.rfft2(S_pad)
    F_K = np.fft.rfft2(kernel_shifted)
    conv = np.fft.irfft2(F_S * F_K, s=S_pad.shape)
    return conv[:n_rows, :n_cols]


def compute_potential(G_hat: np.ndarray,
                      lambd: float,
                      ell: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Compute potential and field from a normalised envelope.

    Given a normalised gradient magnitude ``G_hat`` and a coupling
    constant ``lambd`` (λ), this function constructs a source density
    ``S = lambd · G_hat`` and computes the potential via an aperiodic
    convolution with the Green’s function.  The electric field is
    obtained as the negative gradient of the potential, computed with
    ``np.gradient`` (central differences in the interior, one‑sided at the
    edges, unit lattice spacing).  Along an axis that is only one site
    wide there is no variation to differentiate, so the corresponding
    field component is returned as zeros.

    Parameters
    ----------
    G_hat : ndarray
        Normalised gradient magnitude (``L × L``).  Array‑likes are
        accepted and converted to ``float64``.
    lambd : float
        Coupling coefficient that scales the source density.
    ell : int
        Smoothing width used to set the Green’s function regularisation.

    Returns
    -------
    V : ndarray
        Potential field (``L × L``).
    E_x : ndarray
        Electric field component along the first axis (``L × L``).
    E_y : ndarray
        Electric field component along the second axis (``L × L``).

    Raises
    ------
    ValueError
        If ``G_hat`` is not two‑dimensional.
    """
    G_hat = np.asarray(G_hat, dtype=float)
    if G_hat.ndim != 2:
        raise ValueError(
            f"G_hat must be a 2-D array, got shape {G_hat.shape}"
        )
    # Build source density
    S = lambd * G_hat
    # Solve Poisson equation via convolution
    V = aperiodic_convolution(S, ell)
    # E = -grad V.  np.gradient raises for axes with fewer than 2 samples,
    # so such an axis gets a zero field component instead.
    E_x = -np.gradient(V, axis=0) if V.shape[0] >= 2 else np.zeros_like(V)
    E_y = -np.gradient(V, axis=1) if V.shape[1] >= 2 else np.zeros_like(V)
    return V, E_x, E_y
